{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_mas4n0LooRZ"
      },
      "outputs": [],
      "source": [
        "from transformers import DetrFeatureExtractor, DetrForSegmentation\n",
        "\n",
        "# A single checkpoint name drives both loads, so the pre-processor and the\n",
        "# model weights always come from the same pretrained release.\n",
        "# NOTE(review): recent transformers releases deprecate DetrFeatureExtractor in\n",
        "# favour of DetrImageProcessor -- confirm against the installed version.\n",
        "CHECKPOINT = \"facebook/detr-resnet-50-panoptic\"\n",
        "\n",
        "processor = DetrFeatureExtractor.from_pretrained(CHECKPOINT)\n",
        "model = DetrForSegmentation.from_pretrained(CHECKPOINT)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import io\n",
        "import requests\n",
        "from PIL import Image\n",
        "import torch\n",
        "import numpy\n",
        "\n",
        "url = \"https://farm4.staticflickr.com/3487/3925656789_1b64654c91_z.jpg\"\n",
        "\n",
        "# Download the sample picture. The timeout stops a stalled connection from\n",
        "# hanging the kernel, and raise_for_status() surfaces HTTP errors directly\n",
        "# instead of letting PIL fail later on an error page it cannot decode.\n",
        "response = requests.get(url, timeout=30)\n",
        "response.raise_for_status()\n",
        "\n",
        "# Decode from an in-memory buffer so the image does not depend on a\n",
        "# still-open network stream.\n",
        "image = Image.open(io.BytesIO(response.content))\n"
      ],
      "metadata": {
        "id": "d0srchyLs7Gq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Turn the PIL image into the PyTorch tensors the model expects.\n",
        "inputs = processor(images=image, return_tensors=\"pt\")\n",
        "\n",
        "# Inference only: disable autograd so no computation graph is built or kept\n",
        "# alive for this forward pass.\n",
        "with torch.no_grad():\n",
        "    outputs = model(**inputs)"
      ],
      "metadata": {
        "id": "TqZ6xPkZs-TB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Last two dimensions of the tensor fed to the model (presumably height and\n",
        "# width -- confirm), with a leading axis of size 1 added by unsqueeze(0).\n",
        "processed_sizes = torch.as_tensor(inputs[\"pixel_values\"].shape[-2:]).unsqueeze(0)\n",
        "\n",
        "# Post-process with the `processor` object created when the checkpoint was\n",
        "# loaded; the notebook never defines a `feature_extractor` name, so using it\n",
        "# here raised NameError on a fresh kernel. [0] selects the first result.\n",
        "result = processor.post_process_panoptic(outputs, processed_sizes)[0]\n"
      ],
      "metadata": {
        "id": "fNR2NP_stGkz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import itertools\n",
        "import io\n",
        "import seaborn as sns\n",
        "import numpy\n",
        "from transformers.image_transforms import rgb_to_id, id_to_rgb\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "# Endless colour supply: the seaborn palette repeats once it runs out.\n",
        "palette = itertools.cycle(sns.color_palette())\n",
        "\n",
        "# The post-processing result carries the segmentation as PNG bytes; decode\n",
        "# them into a writable uint8 array.\n",
        "panoptic_seg = Image.open(io.BytesIO(result[\"png_string\"]))\n",
        "panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8).copy()\n",
        "\n",
        "# Per-pixel segment ids derived from the decoded PNG colours.\n",
        "panoptic_seg_id = rgb_to_id(panoptic_seg)\n",
        "\n",
        "# Reuse the decoded array as the output canvas: blank it, then paint every\n",
        "# segment id with the next palette colour (palette floats scaled by 255).\n",
        "panoptic_seg[:, :, :] = 0\n",
        "# The loop variable is deliberately not called `id`: at notebook top level\n",
        "# that would overwrite the builtin id() for every later cell.\n",
        "for segment_id in range(panoptic_seg_id.max() + 1):\n",
        "    panoptic_seg[panoptic_seg_id == segment_id] = numpy.asarray(next(palette)) * 255\n",
        "\n",
        "plt.figure(figsize=(15, 15))\n",
        "plt.imshow(panoptic_seg)\n",
        "plt.axis(\"off\")\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "JifQMwbjtLOg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "nTgrbcchtSii"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}